Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement codemod for walrus-if #40

Merged
merged 4 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci_tests/test_webgoat_findings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"pixee:python/harden-pyyaml",
"pixee:python/django-debug-flag-on",
"pixee:python/url-sandbox",
"pixee:python/use-walrus-if",
]


Expand Down
32 changes: 32 additions & 0 deletions integration_tests/test_use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from codemodder.codemods.use_walrus_if import UseWalrusIf
from integration_tests.base_test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)


class TestUseWalrusIf(BaseIntegrationTest):
drdavella marked this conversation as resolved.
Show resolved Hide resolved
codemod = UseWalrusIf
code_path = "tests/samples/use_walrus_if.py"
original_code, _ = original_and_expected_from_code_path(code_path, [])
expected_new_code = """
if (x := foo()) is not None:
print(x)

if y := bar():
print(y)

z = baz()
print(z)


def whatever():
if (b := biz()) == 10:
print(b)
""".lstrip()

expected_diff = "--- \n+++ \n@@ -1,9 +1,7 @@\n-x = foo()\n-if x is not None:\n+if (x := foo()) is not None:\n print(x)\n \n-y = bar()\n-if y:\n+if y := bar():\n print(y)\n \n z = baz()\n@@ -11,6 +9,5 @@\n \n \n def whatever():\n- b = biz()\n- if b == 10:\n+ if (b := biz()) == 10:\n print(b)\n"

num_changes = 3
expected_line_change = 1
change_description = UseWalrusIf.CHANGE_DESCRIPTION
2 changes: 2 additions & 0 deletions src/codemodder/codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from codemodder.codemods.remove_unnecessary_f_str import RemoveUnnecessaryFStr
from codemodder.codemods.tempfile_mktemp import TempfileMktemp
from codemodder.codemods.requests_verify import RequestsVerify
from codemodder.codemods.use_walrus_if import UseWalrusIf

DEFAULT_CODEMODS = {
DjangoDebugFlagOn,
Expand All @@ -36,6 +37,7 @@
UrlSandbox,
TempfileMktemp,
RequestsVerify,
UseWalrusIf,
}
ALL_CODEMODS = DEFAULT_CODEMODS

Expand Down
5 changes: 1 addition & 4 deletions src/codemodder/codemods/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ def __init__(
# similar when they define their `on_result_found` method.
# Right now this is just to demonstrate a particular use case.
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call):
pos_to_match = self.node_position(original_node)
if self.filter_by_result(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match):
if self.node_is_selected(original_node):
self.report_change(original_node)
if (attr := getattr(self, "on_result_found", None)) is not None:
# pylint: disable=not-callable
Expand Down
6 changes: 6 additions & 0 deletions src/codemodder/codemods/base_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def filter_by_path_includes_or_excludes(self, pos_to_match):
return any(match_line(pos_to_match, line) for line in self.line_include)
return True

def node_is_selected(self, node) -> bool:
pos_to_match = self.node_position(node)
return self.filter_by_result(
pos_to_match
) and self.filter_by_path_includes_or_excludes(pos_to_match)

def node_position(self, node):
# See https://github.com/Instagram/LibCST/blob/main/libcst/_metadata_dependent.py#L112
return self.get_metadata(self.METADATA_DEPENDENCIES[0], node)
Expand Down
141 changes: 141 additions & 0 deletions src/codemodder/codemods/use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from typing import List, Tuple

import libcst as cst
from libcst._position import CodeRange
from libcst import matchers as m
from libcst.metadata import ParentNodeProvider, ScopeProvider

from codemodder.change import Change
from codemodder.codemods.base_codemod import ReviewGuidance
from codemodder.codemods.utils_mixin import NameResolutionMixin
from codemodder.codemods.api import SemgrepCodemod


class UseWalrusIf(SemgrepCodemod, NameResolutionMixin):
METADATA_DEPENDENCIES = SemgrepCodemod.METADATA_DEPENDENCIES + (
ParentNodeProvider,
ScopeProvider,
)
NAME = "use-walrus-if"
SUMMARY = (
"Replaces multiple expressions involving `if` operator with 'walrus' operator"
)
REVIEW_GUIDANCE = ReviewGuidance.MERGE_AFTER_CURSORY_REVIEW
DESCRIPTION = (
"Replaces multiple expressions involving `if` operator with 'walrus' operator"
)

@classmethod
def rule(cls):
return """
rules:
- patterns:
- pattern: |
$ASSIGN
if $COND:
$BODY
- metavariable-pattern:
metavariable: $ASSIGN
patterns:
- pattern: $VAR = $RHS
- metavariable-pattern:
metavariable: $COND
patterns:
- pattern: $VAR
- metavariable-pattern:
metavariable: $BODY
pattern: $VAR
- focus-metavariable: $ASSIGN
"""

_modify_next_if: List[Tuple[CodeRange, cst.Assign]]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._modify_next_if = []

def add_change(self, position: CodeRange):
self.file_context.codemod_changes.append(
Change(
lineNumber=position.start.line,
description=self.CHANGE_DESCRIPTION,
).to_json()
)

def leave_If(self, original_node, updated_node):
if self._modify_next_if:
position, if_node = self._modify_next_if.pop()
is_name = m.matches(updated_node.test, m.Name())
named_expr = cst.NamedExpr(
target=if_node.targets[0].target,
value=if_node.value,
lpar=[] if is_name else [cst.LeftParen()],
rpar=[] if is_name else [cst.RightParen()],
)
self.add_change(position)
return (
updated_node.with_changes(test=named_expr)
if is_name
else updated_node.with_changes(
test=updated_node.test.with_changes(left=named_expr)
)
)

return original_node

def _is_valid_modification(self, node):
"""
Restricts the kind of modifications we can make to the AST.

This is necessary since the semgrep rule can't fully encode this restriction.
"""
if parent := self.get_metadata(ParentNodeProvider, node):
if gparent := self.get_metadata(ParentNodeProvider, parent):
if (idx := gparent.children.index(parent)) >= 0:
return m.matches(
gparent.children[idx + 1],
m.If(test=(m.Name() | m.Comparison(left=m.Name()))),
)
return False

def leave_Assign(self, original_node, updated_node):
if self.node_is_selected(original_node):
if self._is_valid_modification(original_node):
position = self.node_position(original_node)
self._modify_next_if.append((position, updated_node))
return cst.RemoveFromParent()

return original_node

def leave_SimpleStatementLine(self, original_node, updated_node):
drdavella marked this conversation as resolved.
Show resolved Hide resolved
"""
Preserves the whitespace and comments in the line when all children are removed.

This feels like a bug in libCST but we'll work around it for now.
"""
if not updated_node.body:
trailing_whitespace = (
(
original_node.trailing_whitespace.with_changes(
whitespace=cst.SimpleWhitespace(""),
),
)
if original_node.trailing_whitespace.comment
else ()
)
# NOTE: The effect of this is to preserve the
# whitespace and comments. However, the type expected by
# cst.Module.body is Sequence[Union[SimpleStatementLine, BaseCompoundStatement]].
# So technically this violates the expected return type since we
# are not adding a new SimpleStatementLine but instead just bare
# EmptyLine and Comment nodes.
# A more correct solution would involve transferring any whitespace
# and comments to the subsequent SimpleStatementLine (which
# contains the If statement), but this would require a lot more
# state management to fit within the visitor pattern. We should
# revisit this at some point later.
return cst.FlattenSentinel(
original_node.leading_lines + trailing_whitespace
)

return updated_node
150 changes: 150 additions & 0 deletions tests/codemods/test_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from codemodder.codemods.use_walrus_if import UseWalrusIf
from tests.codemods.base_codemod_test import BaseSemgrepCodemodTest


class TestUseWalrusIf(BaseSemgrepCodemodTest):
codemod = UseWalrusIf

@pytest.mark.parametrize(
"condition",
[
"is None",
"is not None",
"== 42",
'!= "bar"',
],
)
def test_simple_use_walrus_if(self, tmpdir, condition):
input_code = f"""
val = do_something()
if val {condition}:
do_something_else(val)
"""
expected_output = f"""
if (val := do_something()) {condition}:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_name_only(self, tmpdir):
input_code = """
val = do_something()
if val:
do_something_else(val)
"""
expected_output = """
if val := do_something():
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_preserve_comments(self, tmpdir):
input_code = """
val = do_something() # comment
if val is not None: # another comment
do_something_else(val)
"""
expected_output = """
# comment
if (val := do_something()) is not None: # another comment
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_multiple(self, tmpdir):
input_code = """
val = do_something()
if val is not None:
do_something_else(val)

foo = hello()
if foo == "bar":
whatever(foo)
"""
expected_output = """
if (val := do_something()) is not None:
do_something_else(val)

if (foo := hello()) == "bar":
whatever(foo)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_in_function(self, tmpdir):
"""Make sure this works inside more complex code"""
input_code = """
def foo():
val = do_something()
if val is not None:
do_something_else(val)
"""
expected_output = """
def foo():
if (val := do_something()) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

def test_walrus_if_nested(self, tmpdir):
"""Make sure this works inside more complex code"""
input_code = """
x = do_something()
if x is not None:
y = do_something_else(x)
if y is not None:
bizbaz(x, y)
"""
expected_output = """
if (x := do_something()) is not None:
if (y := do_something_else(x)) is not None:
bizbaz(x, y)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("space", ["", "\n"])
def test_with_whitespace(self, tmpdir, space):
input_code = f"""
val = do_something(){space}
if val is not None:
do_something_else(val)
"""
expected_output = f"""
{space}if (val := do_something()) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, expected_output)

@pytest.mark.parametrize("variable", ["foo", "value", "oval"])
def test_dont_apply_walrus_different_variable(self, tmpdir, variable):
input_code = f"""
val = do_something()
if {variable} is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_method_call(self, tmpdir):
input_code = """
val = do_something()
if val.method() is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_call_with_func(self, tmpdir):
input_code = """
val = do_something()
if woot(val) is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)

def test_dont_apply_walrus_expr(self, tmpdir):
input_code = """
val = do_something()
if val + 42 is not None:
do_something_else(val)
"""
self.run_and_assert(tmpdir, input_code, input_code)
16 changes: 16 additions & 0 deletions tests/samples/use_walrus_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
x = foo()
if x is not None:
print(x)

y = bar()
if y:
print(y)

z = baz()
print(z)


def whatever():
b = biz()
if b == 10:
print(b)