diff --git a/benchmark/generate_samples.py b/benchmark/generate_samples.py index fdc9463..780741d 100644 --- a/benchmark/generate_samples.py +++ b/benchmark/generate_samples.py @@ -576,6 +576,7 @@ def main() -> None: ) defer_imports_path.write_text(defer_imports_contents, encoding="utf-8") + # Same defer_imports-influenced imports, but for a test in the tests directory tests_path = Path().resolve() / "tests" / "stdlib_imports.py" tests_path.write_text(defer_imports_contents, encoding="utf-8") diff --git a/src/defer_imports/_core.py b/src/defer_imports/_core.py index 82355ba..324d57a 100644 --- a/src/defer_imports/_core.py +++ b/src/defer_imports/_core.py @@ -37,9 +37,9 @@ class DeferredInstrumenter(ast.NodeTransformer): results are assigned to custom keys in the global namespace. """ - def __init__(self, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], data: SourceData, encoding: str) -> None: - self.filepath = filepath + def __init__(self, data: SourceData, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], encoding: str) -> None: self.data = data + self.filepath = filepath self.encoding = encoding self.scope_depth = 0 @@ -269,12 +269,6 @@ def visit_Module(self, node: ast.Module) -> ast.AST: return self.generic_visit(node) -def match_token(token: tokenize.TokenInfo, **kwargs: object) -> bool: - """Check if a given token's attributes match the given kwargs.""" - - return all(getattr(token, name) == val for name, val in kwargs.items()) - - def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Iterable[tuple[_tp.T, ...]]: """Collect data into overlapping fixed-length chunks or blocks. @@ -306,10 +300,10 @@ def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tu encoding = next(token_stream).string uses_defer = any( - match_token(tok1, type=tok_NAME, string="with") - and match_token(tok2, type=tok_NAME, string="defer_imports") - and match_token(tok3, type=tok_OP, string=".") - and match_token(tok4, type=tok_NAME, string="until_use") + (tok1.type == tok_NAME and tok1.string == "with") + and (tok2.type == tok_NAME and tok2.string == "defer_imports") + and (tok3.type == tok_OP and tok3.string == ".") + and (tok4.type == tok_NAME and tok4.string == "until_use") for tok1, tok2, tok3, tok4 in sliding_window(token_stream, 4) ) @@ -351,7 +345,7 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] if not uses_defer: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. - tree = DeferredInstrumenter(path, data, encoding).instrument() + tree = DeferredInstrumenter(data, path, encoding).instrument() return super().source_to_code(tree, path, _optimize=_optimize) # pyright: ignore # See note above. def get_data(self, path: str) -> bytes: diff --git a/src/defer_imports/_typing.py b/src/defer_imports/_typing.py index ad5eb5d..9b1d8a9 100644 --- a/src/defer_imports/_typing.py +++ b/src/defer_imports/_typing.py @@ -70,7 +70,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911 if name == "T": from typing import TypeVar - globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues] + globals()[name] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues] return T if name == "ReadableBuffer": @@ -81,14 +81,14 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911 ReadableBuffer = Union[bytes, bytearray, memoryview] - globals()["ReadableBuffer"] = ReadableBuffer + globals()[name] = ReadableBuffer return ReadableBuffer if name == "StrPath": import os from typing import Union - globals()["StrPath"] = StrPath = Union[str, os.PathLike[str]] + globals()[name] = StrPath = Union[str, os.PathLike[str]] return StrPath if name == "TypeAlias": @@ -99,7 +99,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911 class TypeAlias: """Placeholder for typing.TypeAlias.""" - globals()["TypeAlias"] = TypeAlias + globals()[name] = TypeAlias return TypeAlias msg = f"module {__name__!r} has no attribute {name!r}" diff --git a/src/defer_imports/console.py b/src/defer_imports/console.py index b987c44..676f30a 100644 --- a/src/defer_imports/console.py +++ b/src/defer_imports/console.py @@ -38,7 +38,7 @@ def runsource(self, source: str, filename: str = "", symbol: str = "singl # Case 3: Input is complete. try: - tree = DeferredInstrumenter(filename, source, "utf-8").instrument(symbol) + tree = DeferredInstrumenter(source, filename, "utf-8").instrument(symbol) code = compile(tree, filename, symbol) except SyntaxError: # Case 1, again. diff --git a/tests/test_deferred.py b/tests/test_deferred.py index 08f249f..b8705f1 100644 --- a/tests/test_deferred.py +++ b/tests/test_deferred.py @@ -175,7 +175,7 @@ def test_instrumentation(before: str, after: str): before_bytes = before.encode() encoding, _ = tokenize.detect_encoding(io.BytesIO(before_bytes).readline) - transformed_tree = DeferredInstrumenter("", before_bytes, encoding).instrument() + transformed_tree = DeferredInstrumenter(before_bytes, "", encoding).instrument() assert f"{ast.unparse(transformed_tree)}\n" == after