diff --git a/src/codemodder/project_analysis/python_repo_manager.py b/src/codemodder/project_analysis/python_repo_manager.py index c73e4656..36332dfd 100644 --- a/src/codemodder/project_analysis/python_repo_manager.py +++ b/src/codemodder/project_analysis/python_repo_manager.py @@ -26,5 +26,7 @@ def package_stores(self) -> list[PackageStore]: def _parse_all_stores(self) -> list[PackageStore]: discovered_pkg_stores: list[PackageStore] = [] for store in self._potential_stores: - discovered_pkg_stores.extend(store(self.parent_directory).parse()) + discovered_pkg_stores.extend( + store(self.parent_directory).parse() # type: ignore + ) return discovered_pkg_stores diff --git a/src/core_codemods/remove_unused_imports.py b/src/core_codemods/remove_unused_imports.py index 6d3c43c1..1b54b1a1 100644 --- a/src/core_codemods/remove_unused_imports.py +++ b/src/core_codemods/remove_unused_imports.py @@ -1,4 +1,4 @@ -from libcst import ensure_type, matchers +from libcst import CSTVisitor, ensure_type, matchers from libcst.codemod.visitors import GatherUnusedImportsVisitor from libcst.metadata import ( PositionProvider, @@ -20,7 +20,7 @@ import re from pylint.utils.pragma_parser import parse_pragma -NOQA_PATTERN = re.compile(r"^#\s*noqa") +NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE) class RemoveUnusedImports(BaseCodemod, Codemod): @@ -45,6 +45,9 @@ def __init__(self, codemod_context: CodemodContext, *codemod_args): BaseCodemod.__init__(self, *codemod_args) def transform_module_impl(self, tree: cst.Module) -> cst.Module: + # Do nothing in __init__.py files + if self.file_context.file_path.name == "__init__.py": + return tree gather_unused_visitor = GatherUnusedImportsVisitor(self.context) tree.visit(gather_unused_visitor) # filter the gathered imports by line excludes/includes @@ -78,18 +81,23 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool: if parent and matchers.matches(parent, matchers.SimpleStatementLine()): stmt = ensure_type(parent, cst.SimpleStatementLine) - # has a trailing comment string - trailing_comment_string = ( - stmt.trailing_whitespace.comment.value - if stmt.trailing_whitespace.comment - else None - ) - if trailing_comment_string and NOQA_PATTERN.match(trailing_comment_string): - return True - if trailing_comment_string and _is_pylint_disable_unused_imports( - trailing_comment_string - ): - return True + # has a trailing comment string anywhere in the node + comments_visitor = GatherCommentNodes() + stmt.body[0].visit(comments_visitor) + # has a trailing comment string anywhere in the node + if stmt.trailing_whitespace.comment: + comments_visitor.comments.append(stmt.trailing_whitespace.comment) + + for comment in comments_visitor.comments: + trailing_comment_string = comment.value + if trailing_comment_string and NOQA_PATTERN.match( + trailing_comment_string + ): + return True + if trailing_comment_string and _is_pylint_disable_unused_imports( + trailing_comment_string + ): + return True # has a comment right above it if matchers.matches( @@ -111,25 +119,42 @@ def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool: return False +class GatherCommentNodes(CSTVisitor): + def __init__(self) -> None: + self.comments: list[cst.Comment] = [] + super().__init__() + + def leave_Comment(self, original_node: cst.Comment) -> None: + self.comments.append(original_node) + + def match_line(pos, line): return pos.start.line == line and pos.end.line == line def _is_pylint_disable_unused_imports(comment: str) -> bool: - parsed = parse_pragma(comment) - for p in parsed: - if p.action == "disable" and ( - "unused-import" in p.messages or "W0611" in p.messages - ): - return True + # If pragma parse fails, ignore + try: + parsed = parse_pragma(comment) + for p in parsed: + if p.action == "disable" and ( + "unused-import" in p.messages or "W0611" in p.messages + ): + return True + except Exception: + pass return False def _is_pylint_disable_next_unused_imports(comment: str) -> bool: - parsed = parse_pragma(comment) - for p in parsed: - if p.action == "disable-next" and ( - "unused-import" in p.messages or "W0611" in p.messages - ): - return True + # If pragma parse fails, ignore + try: + parsed = parse_pragma(comment) + for p in parsed: + if p.action == "disable-next" and ( + "unused-import" in p.messages or "W0611" in p.messages + ): + return True + except Exception: + pass return False diff --git a/tests/codemods/base_codemod_test.py b/tests/codemods/base_codemod_test.py index 5cdae259..4bf140c7 100644 --- a/tests/codemods/base_codemod_test.py +++ b/tests/codemods/base_codemod_test.py @@ -23,7 +23,7 @@ def setup_method(self): self.file_context = None def run_and_assert(self, tmpdir, input_code, expected): - tmp_file_path = tmpdir / "code.py" + tmp_file_path = Path(tmpdir / "code.py") self.run_and_assert_filepath(tmpdir, tmp_file_path, input_code, expected) def run_and_assert_filepath(self, root, file_path, input_code, expected): diff --git a/tests/codemods/test_remove_unused_imports.py b/tests/codemods/test_remove_unused_imports.py index 43ee5e55..9e675b5c 100644 --- a/tests/codemods/test_remove_unused_imports.py +++ b/tests/codemods/test_remove_unused_imports.py @@ -1,5 +1,7 @@ +from pathlib import Path from core_codemods.remove_unused_imports import RemoveUnusedImports from tests.codemods.base_codemod_test import BaseCodemodTest +from textwrap import dedent class TestRemoveUnusedImports(BaseCodemodTest): @@ -90,6 +92,17 @@ def test_dont_remove_if_noqa_trailing(self, tmpdir): self.run_and_assert(tmpdir, before, before) assert len(self.file_context.codemod_changes) == 0 + def test_dont_remove_if_noqa_trailing_multiline(self, tmpdir): + before = dedent( + """\ + from _pytest.assertion.util import ( # noqa: F401 + format_explanation as _format_explanation, + )""" + ) + + self.run_and_assert(tmpdir, before, before) + assert len(self.file_context.codemod_changes) == 0 + def test_dont_remove_if_pylint_disable(self, tmpdir): before = "import a\nimport b # pylint: disable=W0611\na()" self.run_and_assert(tmpdir, before, before) @@ -101,3 +114,21 @@ def test_dont_remove_if_pylint_disable_next(self, tmpdir): ) self.run_and_assert(tmpdir, before, before) assert len(self.file_context.codemod_changes) == 0 + + def test_ignore_init_files(self, tmpdir): + before = "import a" + tmp_file_path = Path(tmpdir / "__init__.py") + self.run_and_assert_filepath(tmpdir, tmp_file_path, before, before) + assert len(self.file_context.codemod_changes) == 0 + + def test_no_pyling_pragma_in_comment_trailing(self, tmpdir): + before = "import a # bogus: no-pragma" + after = "" + self.run_and_assert(tmpdir, before, after) + assert len(self.file_context.codemod_changes) == 1 + + def test_no_pyling_pragma_in_comment_before(self, tmpdir): + before = "#header\nprint('hello')\n# bogus: no-pragma\nimport a " + after = "#header\nprint('hello')" + self.run_and_assert(tmpdir, before, after) + assert len(self.file_context.codemod_changes) == 1