Skip to content

Commit

Permalink
Implement codemod for walrus-if
Browse files Browse the repository at this point in the history
  • Loading branch information
drdavella committed Sep 20, 2023
1 parent 5bc8065 commit b73d1cc
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 5 deletions.
1 change: 1 addition & 0 deletions ci_tests/test_webgoat_findings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"pixee:python/harden-pyyaml",
"pixee:python/django-debug-flag-on",
"pixee:python/url-sandbox",
"pixee:python/use-walrus-if",
]


Expand Down
2 changes: 2 additions & 0 deletions src/codemodder/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from codemodder.codemods.remove_unnecessary_f_str import RemoveUnnecessaryFStr
from codemodder.codemods.tempfile_mktemp import TempfileMktemp
from codemodder.codemods.requests_verify import RequestsVerify
from codemodder.codemods.use_walrus_if import UseWalrusIf

DEFAULT_CODEMODS = {
DjangoDebugFlagOn,
Expand All @@ -36,6 +37,7 @@
UrlSandbox,
TempfileMktemp,
RequestsVerify,
UseWalrusIf,
}
ALL_CODEMODS = DEFAULT_CODEMODS

Expand Down
5 changes: 1 addition & 4 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ def __init__(
# similar when they define their `on_result_found` method.
# Right now this is just to demonstrate a particular use case.
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
pos_to_match = self.node_position(original_node)
if self.filter_by_result(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match):
if self.node_is_selected(original_node):
self.report_change(original_node)
if (attr := getattr(self, "on_result_found", None)) is not None:
# pylint: disable=not-callable
Expand Down
6 changes: 6 additions & 0 deletions src/codemodder/codemods/base_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def filter_by_path_includes_or_excludes(self, pos_to_match):
return any(match_line(pos_to_match, line) for line in self.line_include)
return True

def node_is_selected(self, node) -> bool:
pos_to_match = self.node_position(node)
return self.filter_by_result(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match)

def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
return self.get_metadata(self.METADATA_DEPENDENCIES[0], node)
Expand Down
108 changes: 108 additions & 0 deletions src/codemodder/codemods/use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List, Tuple

import libcst as cst
from libcst._position import CodeRange
from libcst import matchers as m
from libcst.metadata import ParentNodeProvider, ScopeProvider

from codemodder.change import Change
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.codemods.api import SemgrepCodemod


class UseWalrusIf(SemgrepCodemod, NameResolutionMixin):
METADATA_DEPENDENCIES = SemgrepCodemod.METADATA_DEPENDENCIES + (
ParentNodeProvider,
ScopeProvider,
)
NAME = "use-walrus-if"
SUMMARY = (
"Replaces multiple expressions involving `if` operator with 'walrus' operator"
)
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
DESCRIPTION = (
"Replaces multiple expressions involving `if` operator with 'walrus' operator"
)

@classmethod
def rule(cls):
return """
rules:
- patterns:
- pattern: |
$ASSIGN
if $COND:
$BODY
- metavariable-pattern:
metavariable: $ASSIGN
patterns:
- pattern: $VAR = $RHS
- metavariable-pattern:
metavariable: $COND
patterns:
- pattern: $VAR
- metavariable-pattern:
metavariable: $BODY
pattern: $VAR
- focus-metavariable: $ASSIGN
"""

_modify_next_if: List[Tuple[CodeRange, cst.Assign]]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modify_next_if = []

def add_change(self, position: CodeRange):
self.file_context.codemod_changes.append(
Change(
lineNumber=position.start.line,
description=self.CHANGE_DESCRIPTION,
).to_json()
)

def leave_If(self, original_node, updated_node):
if self._modify_next_if:
position, if_node = self._modify_next_if.pop()
is_name = m.matches(updated_node.test, m.Name())
named_expr = cst.NamedExpr(
target=if_node.targets[0].target,
value=if_node.value,
lpar=[] if is_name else [cst.LeftParen()],
rpar=[] if is_name else [cst.RightParen()],
)
self.add_change(position)
return (
updated_node.with_changes(test=named_expr)
if is_name
else updated_node.with_changes(
test=updated_node.test.with_changes(left=named_expr)
)
)

return original_node

def _is_valid_modification(self, node):
"""
Restricts the kind of modifications we can make to the AST.
This is necessary since the semgrep rule can't fully encode this restriction.
"""
if parent := self.get_metadata(ParentNodeProvider, node):
if gparent := self.get_metadata(ParentNodeProvider, parent):
if (idx := gparent.children.index(parent)) >= 0:
return m.matches(
gparent.children[idx + 1],
m.If(test=(m.Name() | m.Comparison(left=m.Name()))),
)
return False

def leave_Assign(self, original_node, updated_node):
if self.node_is_selected(original_node):
if self._is_valid_modification(original_node):
position = self.node_position(original_node)
self._modify_next_if.append((position, updated_node))
return cst.RemoveFromParent()

return original_node
136 changes: 136 additions & 0 deletions tests/codemods/test_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import pytest

from codemodder.codemods.use_walrus_if import UseWalrusIf
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest


class TestUseWalrusIf(BaseSemgrepCodemodTest):
codemod = UseWalrusIf

@pytest.mark.parametrize(
"condition",
[
"is None",
"is not None",
"== 42",
'!= "bar"',
],
)
def test_simple_use_walrus_if(self, tmpdir, condition):
input_code = f"""
val = do_something()
if val {condition}:
do_something_else(val)
"""
expected_output = f"""
if (val := do_something()) {condition}:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_name_only(self, tmpdir):
input_code = """
val = do_something()
if val:
do_something_else(val)
"""
expected_output = """
if val := do_something():
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_multiple(self, tmpdir):
input_code = """
val = do_something()
if val is not None:
do_something_else(val)
foo = hello()
if foo == "bar":
whatever(foo)
"""
# TODO: not sure why libcst isn't preserving empty lines
expected_output = """
if (val := do_something()) is not None:
do_something_else(val)
if (foo := hello()) == "bar":
whatever(foo)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_in_function(self, tmpdir):
"""Make sure this works inside more complex code"""
input_code = """
def foo():
val = do_something()
if val is not None:
do_something_else(val)
"""
expected_output = """
def foo():
if (val := do_something()) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_nested(self, tmpdir):
"""Make sure this works inside more complex code"""
input_code = """
x = do_something()
if x is not None:
y = do_something_else(x)
if y is not None:
bizbaz(x, y)
"""
expected_output = """
if (x := do_something()) is not None:
if (y := do_something_else(x)) is not None:
bizbaz(x, y)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("space", ["", "\n"])
def test_with_whitespace(self, tmpdir, space):
input_code = f"""
val = do_something(){space}
if val is not None:
do_something_else(val)
"""
expected_output = f"""
{space}if (val := do_something()) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_dont_apply_walrus_different_variable(self, tmpdir):
input_code = """
val = do_something()
if foo is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_method_call(self, tmpdir):
input_code = """
val = do_something()
if val.method() is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_call_with_func(self, tmpdir):
input_code = """
val = do_something()
if woot(val) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_expr(self, tmpdir):
input_code = """
val = do_something()
if val + 42 is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)
2 changes: 1 addition & 1 deletion tests/samples/unnecessary_f_str.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
bad = f"hello"
bad = "hello"
good = f"{2+3}"

0 comments on commit b73d1cc

Please sign in to comment.