diff --git a/src/defer_imports/_core.py b/src/defer_imports/_core.py index 6008dea..4b53e76 100644 --- a/src/defer_imports/_core.py +++ b/src/defer_imports/_core.py @@ -320,15 +320,6 @@ def check_ast_for_defer_usage(data: ast.AST) -> tuple[str, bool]: class DeferredFileLoader(SourceFileLoader): """A file loader that instruments .py files which use "with defer_imports.until_use: ...".""" - @staticmethod - def check_for_defer_usage(data: SourceData) -> tuple[str, bool]: - """Check if the given data uses "with defer_imports.until_use".""" - - if isinstance(data, ast.AST): - return check_ast_for_defer_usage(data) - else: - return check_source_for_defer_usage(data) - def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] self, data: SourceData, @@ -342,12 +333,16 @@ def source_to_code( # pyright: ignore [reportIncompatibleMethodOverride] if not data: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. - encoding, uses_defer = self.check_for_defer_usage(data) + # Check if the given data uses "with defer_imports.until_use". + if isinstance(data, ast.AST): + encoding, uses_defer = check_ast_for_defer_usage(data) + else: + encoding, uses_defer = check_source_for_defer_usage(data) if not uses_defer: return super().source_to_code(data, path, _optimize=_optimize) # pyright: ignore # See note above. - # Get the AST of the given data, instrument it, and fix missing line and column numbers. + # Instrument the AST of the given data. transformer = DeferredInstrumenter(data, path, encoding) if isinstance(data, ast.AST): diff --git a/src/defer_imports/console.py b/src/defer_imports/console.py index 5a573f7..201400f 100644 --- a/src/defer_imports/console.py +++ b/src/defer_imports/console.py @@ -74,7 +74,7 @@ def __init__(self): self.actual_transformer = DeferredInstrumenter("", "", "utf-8") def visit(self, node: ast.AST) -> _tp.Any: - # Reset part of the wrapped transformer before use. + # Reset part of the wrapped transformer before (re)use. self.actual_transformer.data = node self.actual_transformer.scope_depth = 0 return ast.fix_missing_locations(self.actual_transformer.visit(node)) @@ -84,12 +84,17 @@ def instrument_ipython() -> None: """Add defer_import's compile-time AST transformer to a currently running IPython environment. This will ensure that defer_imports.until_use works as intended when used directly in a IPython console. + + Raises + ------ + RuntimeError + If called in a non-IPython environment. """ try: ipython_shell: _tp.Any = get_ipython() # pyright: ignore except NameError: - msg = "Not currently in an IPython/Jupyter environment." + msg = "Not currently in an IPython environment." raise RuntimeError(msg) from None ipython_shell.ast_transformers.append(_DeferredIPythonInstrumenter())