Skip to content

Commit

Permalink
Switch around some parameters of DeferredInstrumenter.__init__.
Browse files Browse the repository at this point in the history
Adjust arguments everywhere accordingly.
  • Loading branch information
Sachaa-Thanasius committed Sep 5, 2024
1 parent 8dda090 commit 4addf45
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 19 deletions.
1 change: 1 addition & 0 deletions benchmark/generate_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def main() -> None:
)
defer_imports_path.write_text(defer_imports_contents, encoding="utf-8")

# Same defer_imports-influenced imports, but for a test in the tests directory
tests_path = Path().resolve() / "tests" / "stdlib_imports.py"
tests_path.write_text(defer_imports_contents, encoding="utf-8")

Expand Down
20 changes: 7 additions & 13 deletions src/defer_imports/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class DeferredInstrumenter(ast.NodeTransformer):
results are assigned to custom keys in the global namespace.
"""

def __init__(self, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], data: SourceData, encoding: str) -> None:
self.filepath = filepath
def __init__(self, data: SourceData, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], encoding: str) -> None:
self.data = data
self.filepath = filepath
self.encoding = encoding
self.scope_depth = 0

Expand Down Expand Up @@ -269,12 +269,6 @@ def visit_Module(self, node: ast.Module) -> ast.AST:
return self.generic_visit(node)


def match_token(token: tokenize.TokenInfo, **kwargs: object) -> bool:
"""Check if a given token's attributes match the given kwargs."""

return all(getattr(token, name) == val for name, val in kwargs.items())


def sliding_window(iterable: _tp.Iterable[_tp.T], n: int) -> _tp.Iterable[tuple[_tp.T, ...]]:
"""Collect data into overlapping fixed-length chunks or blocks.
Expand Down Expand Up @@ -306,10 +300,10 @@ def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tu
encoding = next(token_stream).string

uses_defer = any(
match_token(tok1, type=tok_NAME, string="with")
and match_token(tok2, type=tok_NAME, string="defer_imports")
and match_token(tok3, type=tok_OP, string=".")
and match_token(tok4, type=tok_NAME, string="until_use")
(tok1.type == tok_NAME and tok1.string == "with")
and (tok2.type == tok_NAME and tok2.string == "defer_imports")
and (tok3.type == tok_OP and tok3.string == ".")
and (tok4.type == tok_NAME and tok4.string == "until_use")
for tok1, tok2, tok3, tok4 in sliding_window(token_stream, 4)
)

Expand Down Expand Up @@ -351,7 +345,7 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride]
if not uses_defer:
return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above.

tree = DeferredInstrumenter(path, data, encoding).instrument()
tree = DeferredInstrumenter(data, path, encoding).instrument()
return super().source_to_code(tree, path, _optimize=_optimize) # pyright: ignore # See note above.

def get_data(self, path: str) -> bytes:
Expand Down
8 changes: 4 additions & 4 deletions src/defer_imports/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
if name == "T":
from typing import TypeVar

globals()["T"] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
globals()[name] = T = TypeVar("T") # pyright: ignore [reportGeneralTypeIssues]
return T

if name == "ReadableBuffer":
Expand All @@ -81,14 +81,14 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911

ReadableBuffer = Union[bytes, bytearray, memoryview]

globals()["ReadableBuffer"] = ReadableBuffer
globals()[name] = ReadableBuffer
return ReadableBuffer

if name == "StrPath":
import os
from typing import Union

globals()["StrPath"] = StrPath = Union[str, os.PathLike[str]]
globals()[name] = StrPath = Union[str, os.PathLike[str]]
return StrPath

if name == "TypeAlias":
Expand All @@ -99,7 +99,7 @@ def __getattr__(name: str) -> object: # pragma: no cover # noqa: PLR0911
class TypeAlias:
"""Placeholder for typing.TypeAlias."""

globals()["TypeAlias"] = TypeAlias
globals()[name] = TypeAlias
return TypeAlias

msg = f"module {__name__!r} has no attribute {name!r}"
Expand Down
2 changes: 1 addition & 1 deletion src/defer_imports/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def runsource(self, source: str, filename: str = "<input>", symbol: str = "singl

# Case 3: Input is complete.
try:
tree = DeferredInstrumenter(filename, source, "utf-8").instrument(symbol)
tree = DeferredInstrumenter(source, filename, "utf-8").instrument(symbol)
code = compile(tree, filename, symbol)
except SyntaxError:
# Case 1, again.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_instrumentation(before: str, after: str):

before_bytes = before.encode()
encoding, _ = tokenize.detect_encoding(io.BytesIO(before_bytes).readline)
transformed_tree = DeferredInstrumenter("<unknown>", before_bytes, encoding).instrument()
transformed_tree = DeferredInstrumenter(before_bytes, "<unknown>", encoding).instrument()

assert f"{ast.unparse(transformed_tree)}\n" == after

Expand Down

0 comments on commit 4addf45

Please sign in to comment.