diff --git a/benchmark/generate_samples.py b/benchmark/generate_samples.py
index fdc9463..780741d 100644
--- a/benchmark/generate_samples.py
+++ b/benchmark/generate_samples.py
@@ -576,6 +576,7 @@ def main() -> None:
)
defer_imports_path.write_text(defer_imports_contents, encoding="utf-8")
+ # Same defer_imports-influenced imports, but for a test in the tests directory
tests_path = Path().resolve() / "tests" / "stdlib_imports.py"
tests_path.write_text(defer_imports_contents, encoding="utf-8")
diff --git a/src/defer_imports/_core.py b/src/defer_imports/_core.py
index 82355ba..324d57a 100644
--- a/src/defer_imports/_core.py
+++ b/src/defer_imports/_core.py
@@ -37,9 +37,9 @@ class DeferredInstrumenter(ast.NodeTransformer):
results are assigned to custom keys in the global namespace.
"""
- def __init__(self, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], data: SourceData, encoding: str) -> None:
- self.filepath = filepath
+ def __init__(self, data: SourceData, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], encoding: str) -> None:
self.data = data
+ self.filepath = filepath
self.encoding = encoding
self.scope_depth = 0
@@ -269,12 +269,6 @@ def visit_Module(self, node: ast.Module) -> ast.AST:
return self.generic_visit(node)
-def match_token(token: tokenize.TokenInfo, **kwargs: object) -> bool:
- """Check if a given token's attributes match the given kwargs."""
-
- return all(getattr(token, name) == val for name, val in kwargs.items())
-
-
def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Iterable[tuple[_tp.T, ...]]:
"""Collect data into overlapping fixed-length chunks or blocks.
@@ -306,10 +300,10 @@ def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tu
encoding = next(token_stream).string
uses_defer = any(
- match_token(tok1, type=tok_NAME, string="with")
- and match_token(tok2, type=tok_NAME, string="defer_imports")
- and match_token(tok3, type=tok_OP, string=".")
- and match_token(tok4, type=tok_NAME, string="until_use")
+ (tok1.type == tok_NAME and tok1.string == "with")
+ and (tok2.type == tok_NAME and tok2.string == "defer_imports")
+ and (tok3.type == tok_OP and tok3.string == ".")
+ and (tok4.type == tok_NAME and tok4.string == "until_use")
for tok1, tok2, tok3, tok4 in sliding_window(token_stream, 4)
)
@@ -351,7 +345,7 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride]
if not uses_defer:
return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above.
- tree = DeferredInstrumenter(path, data, encoding).instrument()
+ tree = DeferredInstrumenter(data, path, encoding).instrument()
return super().source_to_code(tree, path, _optimize=_optimize) # pyright: ignore # See note above.
def get_data(self, path: str) -> bytes:
diff --git a/src/defer_imports/_typing.py b/src/defer_imports/_typing.py
index ad5eb5d..9b1d8a9 100644
--- a/src/defer_imports/_typing.py
+++ b/src/defer_imports/_typing.py
@@ -70,7 +70,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
if name == "T":
from typing import TypeVar
- globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
+ globals()[name] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
return T
if name == "ReadableBuffer":
@@ -81,14 +81,14 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
ReadableBuffer = Union[bytes, bytearray, memoryview]
- globals()["ReadableBuffer"] = ReadableBuffer
+ globals()[name] = ReadableBuffer
return ReadableBuffer
if name == "StrPath":
import os
from typing import Union
- globals()["StrPath"] = StrPath = Union[str, os.PathLike[str]]
+ globals()[name] = StrPath = Union[str, os.PathLike[str]]
return StrPath
if name == "TypeAlias":
@@ -99,7 +99,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
class TypeAlias:
"""Placeholder for typing.TypeAlias."""
- globals()["TypeAlias"] = TypeAlias
+ globals()[name] = TypeAlias
return TypeAlias
msg = f"module {__name__!r} has no attribute {name!r}"
diff --git a/src/defer_imports/console.py b/src/defer_imports/console.py
index b987c44..676f30a 100644
--- a/src/defer_imports/console.py
+++ b/src/defer_imports/console.py
@@ -38,7 +38,7 @@ def runsource(self, source: str, filename: str = "", symbol: str = "singl
# Case 3: Input is complete.
try:
- tree = DeferredInstrumenter(filename, source, "utf-8").instrument(symbol)
+ tree = DeferredInstrumenter(source, filename, "utf-8").instrument(symbol)
code = compile(tree, filename, symbol)
except SyntaxError:
# Case 1, again.
diff --git a/tests/test_deferred.py b/tests/test_deferred.py
index 08f249f..b8705f1 100644
--- a/tests/test_deferred.py
+++ b/tests/test_deferred.py
@@ -175,7 +175,7 @@ def test_instrumentation(before: str, after: str):
before_bytes = before.encode()
encoding, _ = tokenize.detect_encoding(io.BytesIO(before_bytes).readline)
- transformed_tree = DeferredInstrumenter("", before_bytes, encoding).instrument()
+ transformed_tree = DeferredInstrumenter(before_bytes, "", encoding).instrument()
assert f"{ast.unparse(transformed_tree)}\n" == after