Skip to content

Commit

Permalink
Make DeferredFileLoader.source_to_code able to handle ASTs and strs a…
Browse files Browse the repository at this point in the history
…s input.
  • Loading branch information
Sachaa-Thanasius committed Aug 29, 2024
1 parent adb2d9e commit 5707c6c
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 21 deletions.
19 changes: 10 additions & 9 deletions benchmark/bench_samples.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pyright: reportUnusedImport=none
"""Simple benchmark script for comparing the import time of the Python standard library when using regular imports,
deferred-influenced imports, and slothy-influenced imports.
The sample scripts being imported are generated with benchmark/generate_samples.py.
"""

import platform
Expand Down Expand Up @@ -93,12 +95,11 @@ def main() -> None:
exec_order = args.exec_order or list(BENCH_FUNCS)

# Perform benchmarking.
# TODO: Investigate how to make multiple iterations work.
# TODO: Investigate how to make multiple iterations work. Seems like caching is unavoidable.
results = {type_: BENCH_FUNCS[type_]() for type_ in exec_order}
minimum = min(results.values())

# Format and print outcomes.

# Format and print results as a reST-style list table.
impl_header = "Implementation"
impl_len = len(impl_header)
impl_divider = "=" * impl_len
Expand All @@ -120,8 +121,9 @@ def main() -> None:
else:
print("Run once with bytecode caches allowed")

print()
divider = " ".join((impl_divider, version_divider, benchmark_divider, time_divider))

print()
print(divider)
print(impl_header, version_header, benchmark_header, time_header, sep=" ")
print(divider)
Expand All @@ -131,12 +133,11 @@ def main() -> None:

for bench_type, result in results.items():
formatted_result = f"{result:.5f}s ({result / minimum:.2f}x)"

print(
f"{impl:{impl_len}}",
f"{version:{version_len}}",
f"{bench_type:{benchmark_len}}",
f"{formatted_result:{time_len}}",
impl.ljust(impl_len),
version.ljust(version_len),
bench_type.ljust(benchmark_len),
formatted_result.ljust(time_len),
sep=" ",
)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ extend-ignore = [
"ISC001",
"ISC002",
# ---- Project-specific rules
"SIM108", # Ternaries can be less readable than multiline if-else.
]
unfixable = [
"ERA", # Prevent unlikely erroneous deletion.
Expand Down
73 changes: 61 additions & 12 deletions src/deferred/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DeferredInstrumenter(ast.NodeTransformer):
def __init__(
self,
filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer],
data: _tp.ReadableBuffer,
data: _tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive],
encoding: str,
) -> None:
self.filepath = filepath
Expand All @@ -51,7 +51,12 @@ def __init__(
def instrument(self) -> _tp.Any:
    """Instrument the AST built from the stored data and filepath.

    If the stored data is already an AST, it is visited as-is; any other form
    (source text or a readable buffer) is parsed in "exec" mode first. The
    visited tree is returned with its location info repaired.
    """

    if isinstance(self.data, ast.AST):
        tree = self.data
    else:
        tree = ast.parse(self.data, self.filepath, "exec")

    # fix_missing_locations backfills lineno/col_offset on nodes the visit inserted.
    return ast.fix_missing_locations(self.visit(tree))

def _visit_scope(self, node: ast.AST) -> ast.AST:
"""Track Python scope changes. Used to determine if defer_imports_until_use usage is valid."""
Expand All @@ -74,7 +79,14 @@ def _decode_source(self) -> str:
def _get_node_context(self, node: ast.stmt): # noqa: ANN202 # Version-dependent and too verbose.
"""Get the location context for a node. That context will be used as an argument to SyntaxError."""

text = ast.get_source_segment(self._decode_source(), node, padded=True)
if isinstance(self.data, ast.AST):
source = ast.unparse(self.data)
elif isinstance(self.data, str):
source = self.data
else:
source = self._decode_source()

text = ast.get_source_segment(source, node, padded=True)
context = (str(self.filepath), node.lineno, node.col_offset + 1, text)
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
end_col_offset = (node.end_col_offset + 1) if (node.end_col_offset is not None) else None
Expand Down Expand Up @@ -272,21 +284,28 @@ class DeferredFileLoader(SourceFileLoader):
"""A file loader that instruments .py files which use "with defer_imports_until_use: ..."."""

@staticmethod
def check_source_for_defer_usage(data: _tp.ReadableBuffer) -> tuple[str, bool]:
def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tuple[str, bool]:
"""Get the encoding of the given code and also check if it uses "with defer_imports_until_use"."""

tok_NAME, tok_OP = tokenize.NAME, tokenize.OP

token_stream = tokenize.tokenize(io.BytesIO(data).readline)
encoding = next(token_stream).string
if isinstance(data, str):
token_stream = tokenize.generate_tokens(io.StringIO(data).readline)
encoding = "utf-8"
else:
token_stream = tokenize.tokenize(io.BytesIO(data).readline)
encoding = next(token_stream).string

uses_defer = any(
match_token(tok1, type=tok_NAME, string="with")
and (
(
# Allow "with defer_imports_until_use".
match_token(tok2, type=tok_NAME, string="defer_imports_until_use")
and match_token(tok3, type=tok_OP, string=":")
)
or (
# Allow "with deferred.defer_imports_until_use".
match_token(tok2, type=tok_NAME, string="deferred")
and match_token(tok3, type=tok_OP, string=".")
and match_token(tok4, type=tok_NAME, string="defer_imports_until_use")
Expand All @@ -297,26 +316,56 @@ def check_source_for_defer_usage(data: _tp.ReadableBuffer) -> tuple[str, bool]:

return encoding, uses_defer

@staticmethod
def check_ast_for_defer_usage(data: ast.AST) -> tuple[str, bool]:
"""Check if the given AST uses "with defer_imports_until_use". Also assume "utf-8" is the the encoding."""

uses_defer = any(
node
for node in ast.walk(data)
if isinstance(node, ast.With)
and len(node.items) == 1
and (
(
# Allow "with defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Name)
and node.items[0].context_expr.id == "defer_imports_until_use"
)
or (
# Allow "with deferred.defer_imports_until_use".
isinstance(node.items[0].context_expr, ast.Attribute)
and isinstance(node.items[0].context_expr.value, ast.Name)
and node.items[0].context_expr.value.id == "deferred"
and node.items[0].context_expr.attr == "defer_imports_until_use"
)
)
)
encoding = "utf-8"
return encoding, uses_defer

def source_to_code(  # pyright: ignore [reportIncompatibleMethodOverride]
    self,
    data: _tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive],
    path: _tp.Union[_tp.StrPath, _tp.ReadableBuffer],
    *,
    _optimize: int = -1,
) -> _tp.CodeType:
    # NOTE: InspectLoader is the virtual superclass of SourceFileLoader thanks to ABC registration, so typeshed
    # reflects that. However, there's some mismatch in source_to_code signatures. Can it be fixed with a PR?

    if data:
        # Pick the defer-usage check that matches the form of the input.
        if isinstance(data, ast.AST):
            encoding, uses_defer = self.check_ast_for_defer_usage(data)
        else:
            encoding, uses_defer = self.check_source_for_defer_usage(data)

        if uses_defer:
            # Instrument the module, then hand the transformed tree to the regular machinery.
            tree = DeferredInstrumenter(path, data, encoding).instrument()
            return super().source_to_code(tree, path, _optimize=_optimize)  # pyright: ignore # See note above.

    # Empty modules and modules that don't use defer_imports_until_use go through untouched.
    return super().source_to_code(data, path, _optimize=_optimize)  # pyright: ignore # See note above.

def get_data(self, path: str) -> bytes:
"""Return the data from path as raw bytes.
Expand Down

0 comments on commit 5707c6c

Please sign in to comment.