diff --git a/benchmark/generate_samples.py b/benchmark/generate_samples.py index f9b79b8..a1c9d7a 100644 --- a/benchmark/generate_samples.py +++ b/benchmark/generate_samples.py @@ -576,7 +576,7 @@ def main() -> None: ) deferred_path.write_text(deferred_contents, encoding="utf-8") - tests_path = Path().resolve() / "tests" / "sample_deferred.py" + tests_path = Path().resolve() / "tests" / "stdlib_imports.py" tests_path.write_text(deferred_contents, encoding="utf-8") # slothy-hooked imports diff --git a/pyproject.toml b/pyproject.toml index c3c6a46..ca52cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,9 +181,8 @@ keep-runtime-typing = true "T203", # Pretty printing is fine. "ANN201", # Don't need return annotations for tests. "S102", # exec is used to test for NameError within a module's namespace. - "F401", # Unused imports are fine; we're testing import success. ] -"tests/sample_deferred.py" = [ +"tests/stdlib_imports.py" = [ "F401", # Unused imports are fine; we're testing import success. "ERA001", # Plenty of imports are commented out with explanation in comment. ] diff --git a/src/deferred/_core.py b/src/deferred/_core.py index 73c64c6..4906a61 100644 --- a/src/deferred/_core.py +++ b/src/deferred/_core.py @@ -76,7 +76,7 @@ def _decode_source(self) -> str: """Get the source code corresponding to the given data.""" if isinstance(self.data, ast.AST): - # NOTE: An attempt is made here; node location information likely won't match up. + # NOTE: An attempt is made here, but the node location information likely won't match up. return ast.unparse(self.data) elif isinstance(self.data, str): # noqa: RET505 # Readability return self.data @@ -177,6 +177,23 @@ def _substitute_import_keys(self, import_nodes: list[ast.stmt]) -> list[ast.stmt return new_import_nodes + @staticmethod + def check_With_for_defer_usage(node: ast.With) -> bool: + return len(node.items) == 1 and ( + ( + # Allow "with defer_imports_until_use". + isinstance(node.items[0].context_expr, ast.Name) + and node.items[0].context_expr.id == "defer_imports_until_use" + ) + or ( + # Allow "with deferred.defer_imports_until_use". + isinstance(node.items[0].context_expr, ast.Attribute) + and isinstance(node.items[0].context_expr.value, ast.Name) + and node.items[0].context_expr.value.id == "deferred" + and node.items[0].context_expr.attr == "defer_imports_until_use" + ) + ) + def visit_With(self, node: ast.With) -> ast.AST: """Check that "with defer_imports_until_use" blocks are valid and if so, hook all imports within. @@ -189,23 +206,7 @@ def visit_With(self, node: ast.With) -> ast.AST: 3. "defer_imports_until_use" block contains a wildcard import. """ - if not ( - len(node.items) == 1 - and ( - ( - # Allow "with defer_imports_until_use". - isinstance(node.items[0].context_expr, ast.Name) - and node.items[0].context_expr.id == "defer_imports_until_use" - ) - or ( - # Allow "with deferred.defer_imports_until_use". - isinstance(node.items[0].context_expr, ast.Attribute) - and isinstance(node.items[0].context_expr.value, ast.Name) - and node.items[0].context_expr.value.id == "deferred" - and node.items[0].context_expr.attr == "defer_imports_until_use" - ) - ) - ): + if not self.check_With_for_defer_usage(node): return self.generic_visit(node) if self.scope_depth != 0: @@ -322,22 +323,7 @@ def check_ast_for_defer_usage(data: ast.AST) -> tuple[str, bool]: """Check if the given AST uses "with defer_imports_until_use". Also assume "utf-8" is the the encoding.""" uses_defer = any( - isinstance(node, ast.With) - and len(node.items) == 1 - and ( - ( - # Allow "with defer_imports_until_use". - isinstance(node.items[0].context_expr, ast.Name) - and node.items[0].context_expr.id == "defer_imports_until_use" - ) - or ( - # Allow "with deferred.defer_imports_until_use". - isinstance(node.items[0].context_expr, ast.Attribute) - and isinstance(node.items[0].context_expr.value, ast.Name) - and node.items[0].context_expr.value.id == "deferred" - and node.items[0].context_expr.attr == "defer_imports_until_use" - ) - ) + isinstance(node, ast.With) and DeferredInstrumenter.check_With_for_defer_usage(node) for node in ast.walk(data) ) encoding = "utf-8" @@ -494,7 +480,7 @@ class DeferredImportKey(str): When referenced, the key will replace itself in the namespace with the resolved import or the right name from it. """ - __slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing") + __slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_lock") def __new__(cls, key: str, proxy: DeferredImportProxy, /): return super().__new__(cls, key) @@ -532,6 +518,8 @@ def _resolve(self) -> None: # Perform the original __import__ and pray. module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) + # FIXME: The below print is enough to trigger a race condition. + # print(f"{module=}, {proxy.defer_proxy_name=}, {proxy.defer_proxy_fromlist=}, {proxy.defer_proxy_sub=}") # noqa: ERA001 # Transfer nested proxies over to the resolved module. module_vars = vars(module) diff --git a/src/deferred/_typing.py b/src/deferred/_typing.py index 30bc101..e8df442 100644 --- a/src/deferred/_typing.py +++ b/src/deferred/_typing.py @@ -48,62 +48,32 @@ def final(f: object) -> object: return f -def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0912, PLR0915 +def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911 # Let's cache the return values in the global namespace to avoid subsequent calls to __getattr__ if possible. - if name == "T": - from typing import TypeVar - - globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues] - return T - - if name == "Any": - from typing import Any - - globals()["Any"] = Any - return Any - - if name == "CodeType": - from types import CodeType - - globals()["CodeType"] = CodeType - return CodeType - - if name == "Final": - from typing import Final - - globals()["Final"] = Final - return Final - - if name == "Generator": - from collections.abc import Generator + if name in {"Generator", "Iterable", "MutableMapping", "Sequence"}: + import collections.abc - globals()["Generator"] = Generator - return Generator + globals()[name] = res = getattr(collections.abc, name) + return res - if name == "Iterable": - from collections.abc import Iterable + if name in {"Any", "Final", "Optional", "Union"}: + import typing - globals()["Iterable"] = Iterable - return Iterable + globals()[name] = res = getattr(typing, name) + return res - if name == "ModuleType": - from types import ModuleType + if name in {"CodeType", "ModuleType"}: + import types - globals()["ModuleType"] = ModuleType - return ModuleType + globals()[name] = res = getattr(types, name) + return res - if name == "MutableMapping": - from collections.abc import MutableMapping - - globals()["MutableMapping"] = MutableMapping - return MutableMapping - - if name == "Optional": - from typing import Optional + if name == "T": + from typing import TypeVar - globals()["Optional"] = Optional - return Optional + globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues] + return T if name == "ReadableBuffer": if sys.version_info >= (3, 12): @@ -116,12 +86,6 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0 globals()["ReadableBuffer"] = ReadableBuffer return ReadableBuffer - if name == "Sequence": - from collections.abc import Sequence - - globals()["Sequence"] = Sequence - return Sequence - if name == "StrPath": import os from typing import Union @@ -140,11 +104,5 @@ class TypeAlias: globals()["TypeAlias"] = TypeAlias return TypeAlias - if name == "Union": - from typing import Union - - globals()["Union"] = Union - return Union - msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) diff --git a/tests/sample_deferred.py b/tests/stdlib_imports.py similarity index 100% rename from tests/sample_deferred.py rename to tests/stdlib_imports.py diff --git a/tests/test_deferred.py b/tests/test_deferred.py index 459e240..b274626 100644 --- a/tests/test_deferred.py +++ b/tests/test_deferred.py @@ -434,6 +434,30 @@ def test_from_import_with_rename(tmp_path: Path): assert module.MySignature is sys.modules["inspect"].Signature +def test_deferred_header_in_instrumented_pycache(tmp_path: Path): + """Test that the deferred-specific bytecode header is being prepended to the bytecode cache files of + deferred-instrumented modules. + """ + + source = """\ +from deferred import defer_imports_until_use + +with defer_imports_until_use: + import asyncio +""" + + spec, module, path = create_sample_module(tmp_path, source, DeferredFileLoader) + assert spec.loader + spec.loader.exec_module(module) + + expected_cache = Path(importlib.util.cache_from_source(str(path))) + assert expected_cache.is_file() + + with expected_cache.open("rb") as fp: + header = fp.read(len(BYTECODE_HEADER)) + assert header == BYTECODE_HEADER + + def test_error_if_non_import(tmp_path: Path): source = """\ from deferred import defer_imports_until_use @@ -807,7 +831,7 @@ def run(self) -> None: # pragma: no cover finally: del self._target, self._args, self._kwargs # pyright: ignore - def access_module_attr() -> Any: + def access_module_attr() -> object: time.sleep(0.2) return module.inspect.signature @@ -821,38 +845,20 @@ def access_module_attr() -> Any: for thread in threads: thread.join() assert callable(thread.result) # pyright: ignore - # FIXME: Spurious error in pypy3.10 only; sometimes the accessed signature isn't resolved? + # FIXME: Error on various versions; sometimes the accessed signature isn't resolved? + # Recreatable by putting "print(module)" after line 525: + # module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) # noqa: ERA001 + # + # The error: # AssertionError: assert False # + where False = callable() # + where = .result assert thread.exc is None -def test_deferred_header_in_instrumented_pycache(tmp_path: Path): - """Test that the deferred-specific bytecode header is being prepended to the bytecode cache files of - deferred-instrumented modules. - """ - - source = """\ -from deferred import defer_imports_until_use - -with defer_imports_until_use: - import asyncio -""" - - spec, module, path = create_sample_module(tmp_path, source, DeferredFileLoader) - assert spec.loader - spec.loader.exec_module(module) - - expected_cache = Path(importlib.util.cache_from_source(str(path))) - assert expected_cache.is_file() - - with expected_cache.open("rb") as fp: - header = fp.read(len(BYTECODE_HEADER)) - assert header == BYTECODE_HEADER - - def test_import_stdlib(): """Test that we can import most of the stdlib.""" - import tests.sample_deferred # pyright: ignore [reportUnusedImport] + import tests.stdlib_imports + + assert tests.stdlib_imports