Skip to content

Commit

Permalink
Add import of most of stdlib to tests, and make sure it's generated.
Browse files Browse the repository at this point in the history
- Investigate race condition more.
- Simplify _typing __getattr__
- Move ast defer check into a staticmethod to be called in other places.
  • Loading branch information
Sachaa-Thanasius committed Aug 30, 2024
1 parent 9515a76 commit 1bd1b9a
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 124 deletions.
2 changes: 1 addition & 1 deletion benchmark/generate_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def main() -> None:
)
deferred_path.write_text(deferred_contents, encoding="utf-8")

tests_path = Path().resolve() / "tests" / "sample_deferred.py"
tests_path = Path().resolve() / "tests" / "stdlib_imports.py"
tests_path.write_text(deferred_contents, encoding="utf-8")

# slothy-hooked imports
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,8 @@ keep-runtime-typing = true
"T203", # Pretty printing is fine.
"ANN201", # Don't need return annotations for tests.
"S102", # exec is used to test for NameError within a module's namespace.
"F401", # Unused imports are fine; we're testing import success.
]
"tests/sample_deferred.py" = [
"tests/stdlib_imports.py" = [
"F401", # Unused imports are fine; we're testing import success.
"ERA001", # Plenty of imports are commented out with explanation in comment.
]
Expand Down
58 changes: 23 additions & 35 deletions src/deferred/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _decode_source(self) -> str:
"""Get the source code corresponding to the given data."""

if isinstance(self.data, ast.AST):
# NOTE: An attempt is made here; node location information likely won't match up.
# NOTE: An attempt is made here, but the node location information likely won't match up.
return ast.unparse(self.data)
elif isinstance(self.data, str): # noqa: RET505 # Readability
return self.data
Expand Down Expand Up @@ -177,6 +177,23 @@ def _substitute_import_keys(self, import_nodes: list[ast.stmt]) -> list[ast.stmt

return new_import_nodes

@staticmethod
def check_With_for_defer_usage(node: ast.With) -> bool:
return len(node.items) == 1 and (
(
# Allow "with defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Name)
and node.items[0].context_expr.id == "defer_imports_until_use"
)
or (
# Allow "with deferred.defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Attribute)
and isinstance(node.items[0].context_expr.value, ast.Name)
and node.items[0].context_expr.value.id == "deferred"
and node.items[0].context_expr.attr == "defer_imports_until_use"
)
)

def visit_With(self, node: ast.With) -> ast.AST:
"""Check that "with defer_imports_until_use" blocks are valid and if so, hook all imports within.
Expand All @@ -189,23 +206,7 @@ def visit_With(self, node: ast.With) -> ast.AST:
3. "defer_imports_until_use" block contains a wildcard import.
"""

if not (
len(node.items) == 1
and (
(
# Allow "with defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Name)
and node.items[0].context_expr.id == "defer_imports_until_use"
)
or (
# Allow "with deferred.defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Attribute)
and isinstance(node.items[0].context_expr.value, ast.Name)
and node.items[0].context_expr.value.id == "deferred"
and node.items[0].context_expr.attr == "defer_imports_until_use"
)
)
):
if not self.check_With_for_defer_usage(node):
return self.generic_visit(node)

if self.scope_depth != 0:
Expand Down Expand Up @@ -322,22 +323,7 @@ def check_ast_for_defer_usage(data: ast.AST) -> tuple[str, bool]:
"""Check if the given AST uses "with defer_imports_until_use". Also assume "utf-8" is the the encoding."""

uses_defer = any(
isinstance(node, ast.With)
and len(node.items) == 1
and (
(
# Allow "with defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Name)
and node.items[0].context_expr.id == "defer_imports_until_use"
)
or (
# Allow "with deferred.defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Attribute)
and isinstance(node.items[0].context_expr.value, ast.Name)
and node.items[0].context_expr.value.id == "deferred"
and node.items[0].context_expr.attr == "defer_imports_until_use"
)
)
isinstance(node, ast.With) and DeferredInstrumenter.check_With_for_defer_usage(node)
for node in ast.walk(data)
)
encoding = "utf-8"
Expand Down Expand Up @@ -494,7 +480,7 @@ class DeferredImportKey(str):
When referenced, the key will replace itself in the namespace with the resolved import or the right name from it.
"""

__slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing")
__slots__ = ("defer_key_str", "defer_key_proxy", "is_recursing", "_lock")

def __new__(cls, key: str, proxy: DeferredImportProxy, /):
return super().__new__(cls, key)
Expand Down Expand Up @@ -532,6 +518,8 @@ def _resolve(self) -> None:

# Perform the original __import__ and pray.
module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args)
# FIXME: The below print is enough to trigger a race condition.
# print(f"{module=}, {proxy.defer_proxy_name=}, {proxy.defer_proxy_fromlist=}, {proxy.defer_proxy_sub=}") # noqa: ERA001

# Transfer nested proxies over to the resolved module.
module_vars = vars(module)
Expand Down
76 changes: 17 additions & 59 deletions src/deferred/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,62 +48,32 @@ def final(f: object) -> object:
return f


def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0912, PLR0915
def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
# Let's cache the return values in the global namespace to avoid subsequent calls to __getattr__ if possible.

if name == "T":
from typing import TypeVar

globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
return T

if name == "Any":
from typing import Any

globals()["Any"] = Any
return Any

if name == "CodeType":
from types import CodeType

globals()["CodeType"] = CodeType
return CodeType

if name == "Final":
from typing import Final

globals()["Final"] = Final
return Final

if name == "Generator":
from collections.abc import Generator
if name in {"Generator", "Iterable", "MutableMapping", "Sequence"}:
import collections.abc

globals()["Generator"] = Generator
return Generator
globals()[name] = res = getattr(collections.abc, name)
return res

if name == "Iterable":
from collections.abc import Iterable
if name in {"Any", "Final", "Optional", "Union"}:
import typing

globals()["Iterable"] = Iterable
return Iterable
globals()[name] = res = getattr(typing, name)
return res

if name == "ModuleType":
from types import ModuleType
if name in {"CodeType", "ModuleType"}:
import types

globals()["ModuleType"] = ModuleType
return ModuleType
globals()[name] = res = getattr(types, name)
return res

if name == "MutableMapping":
from collections.abc import MutableMapping

globals()["MutableMapping"] = MutableMapping
return MutableMapping

if name == "Optional":
from typing import Optional
if name == "T":
from typing import TypeVar

globals()["Optional"] = Optional
return Optional
globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
return T

if name == "ReadableBuffer":
if sys.version_info >= (3, 12):
Expand All @@ -116,12 +86,6 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911, PLR0
globals()["ReadableBuffer"] = ReadableBuffer
return ReadableBuffer

if name == "Sequence":
from collections.abc import Sequence

globals()["Sequence"] = Sequence
return Sequence

if name == "StrPath":
import os
from typing import Union
Expand All @@ -140,11 +104,5 @@ class TypeAlias:
globals()["TypeAlias"] = TypeAlias
return TypeAlias

if name == "Union":
from typing import Union

globals()["Union"] = Union
return Union

msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
File renamed without changes.
60 changes: 33 additions & 27 deletions tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,30 @@ def test_from_import_with_rename(tmp_path: Path):
assert module.MySignature is sys.modules["inspect"].Signature


def test_deferred_header_in_instrumented_pycache(tmp_path: Path):
"""Test that the deferred-specific bytecode header is being prepended to the bytecode cache files of
deferred-instrumented modules.
"""

source = """\
from deferred import defer_imports_until_use
with defer_imports_until_use:
import asyncio
"""

spec, module, path = create_sample_module(tmp_path, source, DeferredFileLoader)
assert spec.loader
spec.loader.exec_module(module)

expected_cache = Path(importlib.util.cache_from_source(str(path)))
assert expected_cache.is_file()

with expected_cache.open("rb") as fp:
header = fp.read(len(BYTECODE_HEADER))
assert header == BYTECODE_HEADER


def test_error_if_non_import(tmp_path: Path):
source = """\
from deferred import defer_imports_until_use
Expand Down Expand Up @@ -807,7 +831,7 @@ def run(self) -> None: # pragma: no cover
finally:
del self._target, self._args, self._kwargs # pyright: ignore

def access_module_attr() -> Any:
def access_module_attr() -> object:
time.sleep(0.2)
return module.inspect.signature

Expand All @@ -821,38 +845,20 @@ def access_module_attr() -> Any:
for thread in threads:
thread.join()
assert callable(thread.result) # pyright: ignore
# FIXME: Spurious error in pypy3.10 only; sometimes the accessed signature isn't resolved?
# FIXME: Error on various versions; sometimes the accessed signature isn't resolved?
# Recreatable by putting "print(module)" after line 525:
# module: _tp.ModuleType = original_import.get()(*proxy.defer_proxy_import_args) # noqa: ERA001
#
# The error:
# AssertionError: assert False
# + where False = callable(<proxy for 'import inspect as ...'>)
# + where <proxy for 'import inspect as ...'> = <CapturingThread(Thread 4, stopped 22024)>.result
assert thread.exc is None


def test_deferred_header_in_instrumented_pycache(tmp_path: Path):
"""Test that the deferred-specific bytecode header is being prepended to the bytecode cache files of
deferred-instrumented modules.
"""

source = """\
from deferred import defer_imports_until_use
with defer_imports_until_use:
import asyncio
"""

spec, module, path = create_sample_module(tmp_path, source, DeferredFileLoader)
assert spec.loader
spec.loader.exec_module(module)

expected_cache = Path(importlib.util.cache_from_source(str(path)))
assert expected_cache.is_file()

with expected_cache.open("rb") as fp:
header = fp.read(len(BYTECODE_HEADER))
assert header == BYTECODE_HEADER


def test_import_stdlib():
"""Test that we can import most of the stdlib."""

import tests.sample_deferred # pyright: ignore [reportUnusedImport]
import tests.stdlib_imports

assert tests.stdlib_imports

0 comments on commit 1bd1b9a

Please sign in to comment.