diff --git a/benchmark/bench_samples.py b/benchmark/bench_samples.py index 1713e19..b136a03 100644 --- a/benchmark/bench_samples.py +++ b/benchmark/bench_samples.py @@ -1,6 +1,8 @@ # pyright: reportUnusedImport=none """Simple benchark script for comparing the import time of the Python standard library when using regular imports, deferred-influence imports, and slothy-influenced imports. + +The sample scripts being imported are generated with benchmark/generate_samples.py. """ import platform @@ -93,12 +95,11 @@ def main() -> None: exec_order = args.exec_order or list(BENCH_FUNCS) # Perform benchmarking. - # TODO: Investigate how to make multiple iterations work. + # TODO: Investigate how to make multiple iterations work. Seems like caching is unavoidable. results = {type_: BENCH_FUNCS[type_]() for type_ in exec_order} minimum = min(results.values()) - # Format and print outcomes. - + # Format and print results as an reST-style list table. impl_header = "Implementation" impl_len = len(impl_header) impl_divider = "=" * impl_len @@ -120,8 +121,9 @@ def main() -> None: else: print("Run once with bytecode caches allowed") - print() divider = " ".join((impl_divider, version_divider, benchmark_divider, time_divider)) + + print() print(divider) print(impl_header, version_header, benchmark_header, time_header, sep=" ") print(divider) @@ -131,12 +133,11 @@ def main() -> None: for bench_type, result in results.items(): formatted_result = f"{result:.5f}s ({result / minimum:.2f}x)" - print( - f"{impl:{impl_len}}", - f"{version:{version_len}}", - f"{bench_type:{benchmark_len}}", - f"{formatted_result:{time_len}}", + impl.ljust(impl_len), + version.ljust(version_len), + bench_type.ljust(benchmark_len), + formatted_result.ljust(time_len), sep=" ", ) diff --git a/pyproject.toml b/pyproject.toml index 541d951..a0222ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ extend-ignore = [ "ISC001", "ISC002", # ---- Project-specific rules + "SIM108", # Ternaries can be less readable than multiline if-else. ] unfixable = [ "ERA", # Prevent unlikely erroneous deletion. diff --git a/src/deferred/_core.py b/src/deferred/_core.py index ebe6787..e933673 100644 --- a/src/deferred/_core.py +++ b/src/deferred/_core.py @@ -40,7 +40,7 @@ class DeferredInstrumenter(ast.NodeTransformer): def __init__( self, filepath: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], - data: _tp.ReadableBuffer, + data: _tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive], encoding: str, ) -> None: self.filepath = filepath @@ -51,7 +51,12 @@ def __init__( def instrument(self) -> _tp.Any: """Transform the tree created from the given data and filepath.""" - return ast.fix_missing_locations(self.visit(ast.parse(self.data, self.filepath, "exec"))) + if isinstance(self.data, ast.AST): + to_visit = self.data + else: + to_visit = ast.parse(self.data, self.filepath, "exec") + + return ast.fix_missing_locations(self.visit(to_visit)) def _visit_scope(self, node: ast.AST) -> ast.AST: """Track Python scope changes. Used to determine if defer_imports_until_use usage is valid.""" @@ -74,7 +79,14 @@ def _decode_source(self) -> str: def _get_node_context(self, node: ast.stmt): # noqa: ANN202 # Version-dependent and too verbose. """Get the location context for a node. That context will be used as an argument to SyntaxError.""" - text = ast.get_source_segment(self._decode_source(), node, padded=True) + if isinstance(self.data, ast.AST): + source = ast.unparse(self.data) + elif isinstance(self.data, str): + source = self.data + else: + source = self._decode_source() + + text = ast.get_source_segment(source, node, padded=True) context = (str(self.filepath), node.lineno, node.col_offset + 1, text) if sys.version_info >= (3, 10): # pragma: >=3.10 cover end_col_offset = (node.end_col_offset + 1) if (node.end_col_offset is not None) else None @@ -272,21 +284,28 @@ class DeferredFileLoader(SourceFileLoader): """A file loader that instruments .py files which use "with defer_imports_until_use: ...".""" @staticmethod - def check_source_for_defer_usage(data: _tp.ReadableBuffer) -> tuple[str, bool]: + def check_source_for_defer_usage(data: _tp.Union[_tp.ReadableBuffer, str]) -> tuple[str, bool]: """Get the encoding of the given code and also check if it uses "with defer_imports_until_use".""" tok_NAME, tok_OP = tokenize.NAME, tokenize.OP - token_stream = tokenize.tokenize(io.BytesIO(data).readline) - encoding = next(token_stream).string + if isinstance(data, str): + token_stream = tokenize.generate_tokens(io.StringIO(data).readline) + encoding = "utf-8" + else: + token_stream = tokenize.tokenize(io.BytesIO(data).readline) + encoding = next(token_stream).string + uses_defer = any( match_token(tok1, type=tok_NAME, string="with") and ( ( + # Allow "with defer_imports_until_use". match_token(tok2, type=tok_NAME, string="defer_imports_until_use") and match_token(tok3, type=tok_OP, string=":") ) or ( + # Allow "with deferred.defer_imports_until_use". match_token(tok2, type=tok_NAME, string="deferred") and match_token(tok3, type=tok_OP, string=".") and match_token(tok4, type=tok_NAME, string="defer_imports_until_use") @@ -297,26 +316,56 @@ def check_source_for_defer_usage(data: _tp.ReadableBuffer) -> tuple[str, bool]: return encoding, uses_defer + @staticmethod + def check_ast_for_defer_usage(data: ast.AST) -> tuple[str, bool]: + """Check if the given AST uses "with defer_imports_until_use". Also assume "utf-8" is the the encoding.""" + + uses_defer = any( + node + for node in ast.walk(data) + if isinstance(node, ast.With) + and len(node.items) == 1 + and ( + ( + # Allow "with defer_imports_until_use". + isinstance(node.items[0].context_expr, ast.Name) + and node.items[0].context_expr.id == "defer_imports_until_use" + ) + or ( + # Allow "with deferred.defer_imports_until_use". + isinstance(node.items[0].context_expr, ast.Attribute) + and isinstance(node.items[0].context_expr.value, ast.Name) + and node.items[0].context_expr.value.id == "deferred" + and node.items[0].context_expr.attr == "defer_imports_until_use" + ) + ) + ) + encoding = "utf-8" + return encoding, uses_defer + def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] self, - data: _tp.ReadableBuffer, - path: _tp.Union[_tp.ReadableBuffer, _tp.StrPath], + data: _tp.Union[_tp.ReadableBuffer, str, ast.Module, ast.Expression, ast.Interactive], + path: _tp.Union[_tp.StrPath, _tp.ReadableBuffer], *, _optimize: int = -1, ) -> _tp.CodeType: # NOTE: InspectLoader is the virtual superclass of SourceFileLoader thanks to ABC registration, so typeshed - # reflects that. However, there's a slight mismatch in source_to_code signatures. Make a PR? + # reflects that. However, there's some mismatch in source_to_code signatures. Can it be fixed with a PR? - # Fall back to regular importlib machinery if the module is empty or doesn't use defer_imports_until_use. if not data: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. - encoding, uses_defer = self.check_source_for_defer_usage(data) + if isinstance(data, ast.AST): + encoding, uses_defer = self.check_ast_for_defer_usage(data) + else: + encoding, uses_defer = self.check_source_for_defer_usage(data) + if not uses_defer: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. tree = DeferredInstrumenter(path, data, encoding).instrument() - return compile(tree, path, "exec", dont_inherit=True, optimize=_optimize) + return super().source_to_code(tree, path, _optimize=_optimize) # pyright: ignore # See note above. def get_data(self, path: str) -> bytes: """Return the data from path as raw bytes.