Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Honor annotations in subprocess-shell-false #259

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions src/codemodder/codemods/check_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import re
from typing import Mapping

import libcst as cst
from libcst import CSTVisitor
from libcst.metadata import ParentNodeProvider
from libcst.metadata.base_provider import ProviderT # noqa: F401
from libcst._nodes.base import CSTNode # noqa: F401

from pylint.utils.pragma_parser import parse_pragma

# Matches a comment that begins with ``# noqa`` (any amount of whitespace after
# the ``#``, case-insensitive). Group 1 is optional and, when present, holds a
# colon, at least one whitespace character and ONE code made of letters
# followed by letters/digits (e.g. ``: S603``). Anything after that first code
# is not captured, and ``# noqa:S603`` (no whitespace after the colon) matches
# with group 1 unset.
NOQA_PATTERN = re.compile(r"^#\s*noqa(:\s+[A-Z]+[A-Z0-9]+)?", re.IGNORECASE)


__all__ = ["is_disabled_by_annotations"]


class _GatherCommentNodes(CSTVisitor):
METADATA_DEPENDENCIES = (ParentNodeProvider,)

messages: list[str]

def __init__(
self,
metadata: Mapping[ProviderT, Mapping[CSTNode, object]],
messages: list[str],
) -> None:
self.comments: list[cst.Comment] = []
super().__init__()
self.metadata = metadata
self.messages = messages

def leave_Comment(self, original_node: cst.Comment) -> None:
self.comments.append(original_node)

def _process_simple_statement_line(self, stmt: cst.SimpleStatementLine) -> bool:
# has a trailing comment string anywhere in the node
stmt.body[0].visit(self)
# has a trailing comment string anywhere in the node
if stmt.trailing_whitespace.comment:
self.comments.append(stmt.trailing_whitespace.comment)

for comment in self.comments:
trailing_comment_string = comment.value
if trailing_comment_string and self._noqa_message_match(
trailing_comment_string
):
return True
if trailing_comment_string and self._is_pylint_disable_unused_imports(
trailing_comment_string
):
return True

# has a comment right above it
match stmt:
case cst.SimpleStatementLine(
leading_lines=[
*_,
cst.EmptyLine(comment=cst.Comment(value=comment_string)),
]
):
return self._noqa_message_match(comment_string) or (
self._is_pylint_disable_next_unused_imports(comment_string)
)

return False

def is_disabled_by_linter(self, node: cst.CSTNode) -> bool:
"""
Check if the import has a #noqa or # pylint: disable(-next) comment attached to it.
"""
match self.get_metadata(ParentNodeProvider, node):
case cst.SimpleStatementLine() as stmt:
return self._process_simple_statement_line(stmt)
case cst.Expr() as expr:
match self.get_metadata(ParentNodeProvider, expr):
case cst.SimpleStatementLine() as stmt:
return self._process_simple_statement_line(stmt)
return False

def _noqa_message_match(self, comment: str) -> bool:
if not (match := NOQA_PATTERN.match(comment)):
return False

if match.group(1):
return match.group(1).strip(":").strip() in self.messages

return True

def _is_pylint_disable_unused_imports(self, comment: str) -> bool:
# If pragma parse fails, ignore
try:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable" and any(
message in p.messages for message in self.messages
):
return True
except Exception:
pass
return False

def _is_pylint_disable_next_unused_imports(self, comment: str) -> bool:
# If pragma parse fails, ignore
try:
parsed = parse_pragma(comment)
for p in parsed:
if p.action == "disable-next" and any(
message in p.messages for message in self.messages
):
return True
except Exception:
pass
return False


def is_disabled_by_annotations(
    node: cst.CSTNode,
    metadata: Mapping[ProviderT, Mapping[CSTNode, object]],
    messages: list[str],
) -> bool:
    """
    Tell whether ``node`` is covered by a ``# noqa`` or
    ``# pylint: disable(-next)`` comment that names one of ``messages``.

    A fresh comment-gathering visitor is built for every call: it first walks
    ``node`` itself to pick up the comments inside it, then inspects the
    statement that contains ``node``.
    """
    checker = _GatherCommentNodes(metadata, messages)
    node.visit(checker)
    return checker.is_disabled_by_linter(node)
98 changes: 8 additions & 90 deletions src/core_codemods/remove_unused_imports.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import re

import libcst as cst
from libcst import CSTVisitor, ensure_type, matchers
from libcst.codemod.visitors import GatherUnusedImportsVisitor
from libcst.metadata import (
PositionProvider,
Expand All @@ -10,16 +7,13 @@
ParentNodeProvider,
)

from pylint.utils.pragma_parser import parse_pragma

from core_codemods.api import SimpleCodemod, Metadata, ReviewGuidance
from codemodder.change import Change
from codemodder.codemods.check_annotations import is_disabled_by_annotations
from codemodder.codemods.transformations.remove_unused_imports import (
RemoveUnusedImportsTransformer,
)

NOQA_PATTERN = re.compile(r"^#\s*noqa", re.IGNORECASE)


class RemoveUnusedImports(SimpleCodemod):
metadata = Metadata(
Expand All @@ -36,6 +30,8 @@ class RemoveUnusedImports(SimpleCodemod):
ParentNodeProvider,
)

IGNORE_ANNOTATIONS = ["unused-import", "F401", "W0611"]

def transform_module_impl(self, tree: cst.Module) -> cst.Module:
# Do nothing in __init__.py files
if self.file_context.file_path.name == "__init__.py":
Expand All @@ -47,7 +43,11 @@ def transform_module_impl(self, tree: cst.Module) -> cst.Module:
for import_alias, importt in gather_unused_visitor.unused_imports:
pos = self.get_metadata(PositionProvider, import_alias)
if self.filter_by_path_includes_or_excludes(pos):
if not self._is_disabled_by_linter(importt):
if not is_disabled_by_annotations(
importt,
self.metadata, # type: ignore
messages=self.IGNORE_ANNOTATIONS,
):
self.file_context.codemod_changes.append(
Change(pos.start.line, self.change_description)
)
Expand All @@ -65,88 +65,6 @@ def filter_by_path_includes_or_excludes(self, pos_to_match) -> bool:
return any(match_line(pos_to_match, line) for line in self.line_include)
return True

def _is_disabled_by_linter(self, node: cst.CSTNode) -> bool:
    """
    Report whether the import carries a ``# noqa`` comment or a
    ``# pylint: disable(-next)`` pragma for unused imports.

    Looks at comments inside the statement, at the end of its line and on the
    line directly above it. Nodes whose parent is not a simple statement line
    are never considered disabled.
    """
    parent = self.get_metadata(ParentNodeProvider, node)
    if not parent or not matchers.matches(parent, matchers.SimpleStatementLine()):
        return False
    stmt = ensure_type(parent, cst.SimpleStatementLine)

    # Comments nested in the first small statement, plus the trailing one.
    collector = GatherCommentNodes()
    stmt.body[0].visit(collector)
    same_line_comments = collector.comments
    if stmt.trailing_whitespace.comment:
        same_line_comments.append(stmt.trailing_whitespace.comment)

    for found in same_line_comments:
        text = found.value
        if not text:
            continue
        if NOQA_PATTERN.match(text) or _is_pylint_disable_unused_imports(text):
            return True

    # Comment sitting on the line right above the statement.
    comment_above = matchers.SimpleStatementLine(
        leading_lines=[
            matchers.ZeroOrMore(),
            matchers.EmptyLine(comment=matchers.Comment()),
        ]
    )
    if not matchers.matches(stmt, comment_above):
        return False

    above_text = stmt.leading_lines[-1].comment.value
    if NOQA_PATTERN.match(above_text):
        return True
    return bool(above_text and _is_pylint_disable_next_unused_imports(above_text))


class GatherCommentNodes(CSTVisitor):
    """Visitor that accumulates every ``cst.Comment`` it walks past."""

    def __init__(self) -> None:
        super().__init__()
        # Comments in the order the visitor leaves them.
        self.comments: list[cst.Comment] = []

    def leave_Comment(self, original_node: cst.Comment) -> None:
        """Store the comment node that was just visited."""
        self.comments += [original_node]


def match_line(pos, line):
    """Return True when ``pos`` both starts and ends on ``line``."""
    first, last = pos.start.line, pos.end.line
    return (first, last) == (line, line)


def _is_pylint_disable_unused_imports(comment: str) -> bool:
    """Return True if ``comment`` is a pylint ``disable`` pragma that lists
    ``unused-import`` or ``W0611``.

    Any error raised while parsing the pragma means "not a suppression".
    """
    targets = ("unused-import", "W0611")
    try:
        return any(
            pragma.action == "disable"
            and any(target in pragma.messages for target in targets)
            for pragma in parse_pragma(comment)
        )
    except Exception:
        # Parsing is best-effort; a comment that cannot be parsed is ignored.
        return False


def _is_pylint_disable_next_unused_imports(comment: str) -> bool:
    """Return True if ``comment`` is a pylint ``disable-next`` pragma that
    lists ``unused-import`` or ``W0611``.

    Any error raised while parsing the pragma means "not a suppression".
    """
    targets = ("unused-import", "W0611")
    try:
        return any(
            pragma.action == "disable-next"
            and any(target in pragma.messages for target in targets)
            for pragma in parse_pragma(comment)
        )
    except Exception:
        # Parsing is best-effort; a comment that cannot be parsed is ignored.
        return False
10 changes: 10 additions & 0 deletions src/core_codemods/subprocess_shell_false.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import libcst as cst
from libcst import matchers
from libcst.metadata import ParentNodeProvider

from codemodder.codemods.check_annotations import is_disabled_by_annotations
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.codemods.libcst_transformer import NewArg
from core_codemods.api import (
Expand Down Expand Up @@ -31,6 +34,9 @@ class SubprocessShellFalse(SimpleCodemod, NameResolutionMixin):
for func in {"run", "call", "check_output", "check_call", "Popen"}
]

METADATA_DEPENDENCIES = SimpleCodemod.METADATA_DEPENDENCIES + (ParentNodeProvider,)
IGNORE_ANNOTATIONS = ["S603"]

def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
Expand All @@ -44,6 +50,10 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
matchers.Arg(
keyword=matchers.Name("shell"), value=matchers.Name("True")
),
) and not is_disabled_by_annotations(
original_node,
self.metadata, # type: ignore
messages=self.IGNORE_ANNOTATIONS,
):
self.report_change(original_node)
new_args = self.replace_args(
Expand Down
21 changes: 21 additions & 0 deletions tests/codemods/test_subprocess_shell_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,24 @@ def test_exclude_line(self, tmpdir):
expected,
lines_to_exclude=lines_to_exclude,
)

@each_func
def test_has_noqa(self, tmpdir, func):
    """A call annotated with ``# noqa: S603`` must be left untouched."""
    unchanged = f"""
import subprocess
subprocess.{func}(args, shell=True) # noqa: S603
"""
    # Input and expected output are the very same source text.
    self.run_and_assert(tmpdir, unchanged, unchanged)

def test_different_noqa_message(self, tmpdir):
    """A ``# noqa`` naming another code (S604) must not block the fix."""
    before = """
import subprocess
subprocess.run(args, shell=True) # noqa: S604
"""
    after = """
import subprocess
subprocess.run(args, shell=False) # noqa: S604
"""
    self.run_and_assert(tmpdir, before, after)
Loading